




















import time
from requests import Response
from threading import Lock
from quant.const import Event, NAME_CHANNEL_CLOSE
from quant.markets import Channel, Handler, MarketData, OrderBook, Functions
from quant.accounts.api import Api, Spi
from quant.accounts.models import UserData
from quant.utils import logging, abstract_method, Iota, Saving
from quant.exchanges.util import create_md_id, Socket, check_book_cross, next_cid_head
from quant.exchanges.util_req import Requesting
from quant.const import UserEvent, OrderStatus, ApiType
from quant import config
from quant.exchanges.util import Socket


class ChannelBasic(Channel):
    """Channel that delegates transport to ``self.socket`` and forwards frames.

    Connection management is passed straight through to the ``Socket``.
    Incoming messages and close notifications are handed to
    ``self.markets.feed_raw`` together with the local receive time.
    Subclasses provide the ``subscribe_*`` hooks and may override ``unzip``.
    """

    socket: Socket
    # Created once, when the class body runs, so this single Lock object is
    # shared by every instance (and subclass) that does not rebind ``_lock``.
    _lock = Lock()

    def connect(self):
        """Ask the underlying socket to connect."""
        self.socket.connect()

    def disconnect(self):
        """Ask the underlying socket to disconnect."""
        self.socket.disconnect()

    def reconnect(self):
        """Ask the underlying socket to reconnect."""
        self.socket.reconnect()

    def is_open(self):
        """Return whatever ``self.socket.is_connected()`` reports."""
        return self.socket.is_connected()

    def on_open(self, ws):
        """Call the ``subscribe_*`` hook that matches ``self.event``.

        Raises:
            NotImplementedError: if the event is not Book, Trade or Ticker.
        """
        kind = self.event

        if kind == Event.Book:
            self.subscribe_book(ws)
            return
        if kind == Event.Trade:
            self.subscribe_trade(ws)
            return
        if kind == Event.Ticker:
            self.subscribe_ticker(ws)
            return

        raise NotImplementedError('Event {} not implemented'.format(kind))

    def on_message(self, ws, message):
        """Record the reception, decode the frame and feed it to the markets.

        Everything runs while holding ``_lock``. The receive timestamp is
        taken after ``push_channel_recv`` and before ``unzip``.
        """
        with self._lock:
            markets = self.markets
            markets.info_engine.push_channel_recv(self)

            received_at = time.time()
            payload = self.unzip(message)
            markets.feed_raw(
                self.event, self.exchange, self.symbol, self.frequency,
                received_at, payload,
            )

    def on_close(self, ws):
        """Feed the ``NAME_CHANNEL_CLOSE`` marker to the markets under the lock."""
        with self._lock:
            closed_at = time.time()
            self.markets.feed_raw(
                self.event, self.exchange, self.symbol, self.frequency,
                closed_at, NAME_CHANNEL_CLOSE,
            )

    @abstract_method
    def subscribe_book(self, ws):
        """Hook called by ``on_open`` when ``self.event == Event.Book``."""

    @abstract_method
    def subscribe_trade(self, ws):
        """Hook called by ``on_open`` when ``self.event == Event.Trade``."""

    @abstract_method
    def subscribe_ticker(self, ws):
        """Hook called by ``on_open`` when ``self.event == Event.Ticker``."""

    def unzip(self, message):
        """Return ``message`` unchanged; override to decode/decompress frames."""
        return message


class HandlerBasic(Handler):
    """Handler that wraps parsed data in ``MarketData`` and routes it.

    ``self.process`` is bound in ``__init__`` to the ``process_*`` hook that
    matches the event. The ``push_*`` helpers build a ``MarketData`` record,
    update the shared ``markets`` caches where applicable, and hand the record
    to ``markets.routing_engine``.
    """

    data_id_checker = None
    data_id_checker_2 = None

    def __init__(self, event, exchange, symbol, frequency, markets):
        """Bind ``self.process`` for ``event``; Book handlers also get a book.

        Raises:
            NotImplementedError: if the event is not Book, Trade or Ticker.
        """
        super().__init__(event, exchange, symbol, frequency, markets)

        if event == Event.Book:
            # Only book handlers own an OrderBook instance.
            self.book = OrderBook()
            self.process = self.process_book
        elif event == Event.Trade:
            self.process = self.process_trade
        elif event == Event.Ticker:
            self.process = self.process_ticker
        else:
            raise NotImplementedError('Handler for {} not implemented'.format(event))

    def _new_market_data(self, routing_key, recv_time, server_time, server_id, data):
        """Build a ``MarketData`` for this handler with a fresh ``create_md_id()``."""
        return MarketData(
            self.event, self.exchange, self.symbol, self.frequency,
            routing_key, recv_time, server_time, create_md_id(), server_id, data,
        )

    def info_welcome(self):
        """Log that this handler's stream was welcomed."""
        identity = (self.exchange, self.event, self.symbol, self.frequency)
        logging.info('{}({}, {}, {}) welcomed'.format(*identity))

    def process_close(self, routing_key, recv_time):
        """Handle a channel-close notification.

        Nothing further happens while any channel registered for this
        (event, exchange, symbol, frequency) still reports open. Once all are
        closed, the cached mid (Book and Ticker) and order book (Book only)
        are reset to ``None``, the local book is cleared, and a ``MarketData``
        carrying ``None`` as its data is pushed to the routing engine.
        """
        event = self.event
        exchange = self.exchange
        symbol = self.symbol
        freq = self.frequency
        markets = self.markets

        logging.info('{}({}, {}, {}) process close'.format(exchange, event, symbol, freq))
        channels = markets.channels.get_default(event, exchange, symbol, freq, [])
        # Query every channel before deciding (no short-circuit).
        open_flags = [channel.is_open() for channel in channels]
        if any(open_flags):
            return
        logging.info('{}({}, {}, {}) all closed'.format(exchange, event, symbol, freq))

        is_book = event == Event.Book
        if is_book or event == Event.Ticker:
            markets.mids.set(exchange, symbol, None)
        if is_book:
            markets.order_books.set(exchange, symbol, None)
            self.book.clear()

        closed = self._new_market_data(routing_key, recv_time, None, None, None)
        markets.routing_engine.push_data(routing_key, closed)

    def check_book(self, book):
        """Log an error when ``check_book_cross(book)`` returns a truthy value."""
        if check_book_cross(book):
            logging.error('book({}, {}) cross!'.format(self.exchange, self.symbol))

    def push_book(self, routing_key, recv_time, server_time, server_id, book):
        """Cache the book and its middle price, then route the book data."""
        markets = self.markets
        record = self._new_market_data(routing_key, recv_time, server_time, server_id, book)

        exchange, symbol = self.exchange, self.symbol
        markets.mids.set(exchange, symbol, book.get_middle())
        markets.order_books.set(exchange, symbol, book)
        markets.routing_engine.push_data(routing_key, record)

    def push_trade(self, routing_key, recv_time, server_time, server_id, trade_list):
        """Route ``trade_list`` wrapped in a ``MarketData`` record."""
        record = self._new_market_data(routing_key, recv_time, server_time, server_id, trade_list)
        self.markets.routing_engine.push_data(routing_key, record)

    def push_ticker(self, routing_key, recv_time, server_time, server_id, ticker):
        """Cache the ticker mid price, then route the ticker data.

        The mid is the average of ``ticker[0][0]`` and ``ticker[1][0]``.
        """
        record = self._new_market_data(routing_key, recv_time, server_time, server_id, ticker)

        first_price = ticker[0][0]
        second_price = ticker[1][0]
        mid = (first_price + second_price) / 2

        markets = self.markets
        markets.mids.set(self.exchange, self.symbol, mid)
        markets.routing_engine.push_data(routing_key, record)

    @abstract_method
    def process_book(self, routing_key, recv_time, raw):
        """Hook bound to ``self.process`` for ``Event.Book`` handlers."""

    @abstract_method
    def process_trade(self, routing_key, recv_time, raw):
        """Hook bound to ``self.process`` for ``Event.Trade`` handlers."""

    @abstract_method
    def process_ticker(self, routing_key, recv_time, raw):
        """Hook bound to ``self.process`` for ``Event.Ticker`` handlers."""


class ApiBasic(Api):
    """Common bookkeeping for request-based account APIs.

    Holds a ``Requesting`` worker pool, the sets of registered symbols and
    assets, and a client-id generator (``_cid_head`` + running counter).
    Subclasses supply the ``_on_*`` response hooks.
    """

    def __init__(self, keys, account=None):
        super().__init__(keys, account)
        self.requesting = Requesting(config.rest_workers)

        self._symbol_added = set()
        self._asset_added = set()
        self._cid_head = next_cid_head()
        self._cid_count = Iota()
        self._ignore_offset = False

    def add_symbol(self, symbol):
        """Register ``symbol`` and seed its default account entries.

        Symbols containing ``'.'`` get a default position entry; all others
        get default balance entries. The part before the first ``'.'`` must
        split on ``'/'`` into exactly two assets, otherwise the unpacking
        raises ``ValueError``.
        """
        seed_defaults = (
            self._set_default_position if '.' in symbol
            else self._set_default_balance
        )
        seed_defaults(symbol)

        pair = symbol.split('.')[0]
        first_asset, second_asset = pair.split('/')
        self._asset_added.update((second_asset, first_asset))
        self._symbol_added.add(symbol)

    def query_balance(self, symbol=None):
        """No-op in this base class; returns ``None``."""
        return None

    def query_margin(self, symbol=None):
        """No-op in this base class; returns ``None``."""
        return None

    def query_position(self, symbol=None):
        """No-op in this base class; returns ``None``."""
        return None

    def join(self):
        """Return the result of ``self.requesting.join()``."""
        return self.requesting.join()

    def ignore_offset(self, ignore=False):
        """Store the ``ignore`` flag in ``self._ignore_offset``."""
        self._ignore_offset = ignore

    def set_ws_mode(self, ws=True):
        """Always raise ``NotImplementedError`` in this base class."""
        raise NotImplementedError('Ws mode not supported')

    @abstract_method
    def _on_balance(self, symbol, response: Response):
        """Subclass hook; receives a symbol and a ``Response``."""

    @abstract_method
    def _on_margin(self, symbol, response: Response):
        """Subclass hook; receives a symbol and a ``Response``."""

    @abstract_method
    def _on_position(self, symbol, response: Response):
        """Subclass hook; receives a symbol and a ``Response``."""

    @abstract_method
    def _on_open_orders(self, symbol, response: Response):
        """Subclass hook; receives a symbol and a ``Response``."""

    @abstract_method
    def _on_place(self, order, response: Response):
        """Subclass hook; receives an order and a ``Response``."""

    @abstract_method
    def _on_cancel(self, order, response: Response):
        """Subclass hook; receives an order and a ``Response``."""

    def _iter_default_symbol(self, symbol):
        """Yield ``symbol``, or every registered symbol when it is ``None``."""
        targets = self._symbol_added if symbol is None else (symbol,)
        yield from targets

    def _create_cid(self):
        """Return the next client id: ``_cid_head`` followed by the counter."""
        return '{}{}'.format(self._cid_head, self._cid_count.next())

    def _fail_request(self, api_name, model, response):
        """Report a failed call to the account's info engine.

        A ``status_code`` equal to ``-1`` is reported through
        ``put_request_fail``; any other status through ``put_operation_fail``.
        """
        info_engine = self.account.info_engine
        if int(response.status_code) == -1:
            report = info_engine.put_request_fail
        else:
            report = info_engine.put_operation_fail
        report(api_name, model, response)

    def _put_data(self, event, data):
        """Publish ``data`` as ``UserData`` tagged ``ApiType.Api``.

        The info engine receives it first, then the routing engine.
        """
        account = self.account
        payload = UserData(event, ApiType.Api, data)
        account.info_engine.put_user_data(payload)
        account.routing_engine.put_user_data(payload)

    def _set_default_balance(self, symbol):
        """Ensure each ``'/'``-separated asset has a zeroed balance entry."""
        balance = self.account.balance
        for asset in symbol.split('/'):
            balance.setdefault(asset, {'free': 0, 'frozen': 0})

    def _set_default_position(self, symbol):
        """Ensure ``symbol`` has a zeroed position entry."""
        empty_position = {'long': 0, 'short': 0, 'position': 0}
        self.account.position.setdefault(symbol, empty_position)


class SpiBasic(Spi):
    """Common bookkeeping for socket-based account feeds.

    Tracks registered symbols/assets, generates client ids, and owns a single
    ``Socket`` that is created on the first ``connect_once`` call.
    """

    _socket: Socket

    def __init__(self, keys, account=None):
        super().__init__(keys, account)
        self._once_connect = False
        self._symbol_added = set()
        self._asset_added = set()
        self._cid_head = next_cid_head()
        self._cid_count = Iota()

    def connect_once(self):
        """Create and connect the socket on the first call; later calls no-op.

        The flag is set before the socket is created, so a failure while
        creating or connecting is not retried by a subsequent call.
        """
        if self._once_connect:
            return

        self._once_connect = True
        self._socket = self._create_socket()
        self._socket.connect()

    def add_symbol(self, symbol):
        """Register ``symbol`` and the two assets named before its first ``'.'``.

        Raises ``ValueError`` (from unpacking) when that part does not split
        on ``'/'`` into exactly two pieces.
        """
        pair = symbol.split('.')[0]
        first_asset, second_asset = pair.split('/')
        self._asset_added.update((second_asset, first_asset))
        self._symbol_added.add(symbol)

    def _create_cid(self):
        """Return the next client id: ``_cid_head`` followed by the counter."""
        return '{}{}'.format(self._cid_head, self._cid_count.next())

    @abstract_method
    def _create_socket(self):
        """Subclass hook; its return value becomes ``self._socket``."""

    def _on_open(self, ws):
        """Default no-op callback."""

    def _login(self, ws):
        """Default no-op callback."""

    def _on_login(self, ws):
        """Default no-op callback."""

    def _on_close(self, ws):
        """Default no-op callback."""

    def _on_message(self, ws, message):
        """Default no-op callback."""

    def _put_data(self, event, data):
        """Publish ``data`` as ``UserData`` tagged ``ApiType.Spi``.

        The info engine receives it first, then the routing engine.
        """
        account = self.account
        payload = UserData(event, ApiType.Spi, data)
        account.info_engine.put_user_data(payload)
        account.routing_engine.put_user_data(payload)

    def _put_deal(self, order, amount):
        """Publish a deal: route ``(order, amount)``, then notify the info engine."""
        account = self.account
        deal = UserData(UserEvent.Deal, ApiType.Spi, (order, amount))
        account.routing_engine.put_user_data(deal)

        account.info_engine.put_deal(order, amount)


class ApiRouter(Api):
    """Api facade that forwards each call to an agent from ``agent_map``.

    Per-symbol calls go to ``_get_agent(symbol)``; ``*_all_*`` style calls
    are broadcast to every agent. Balance queries are sent only to the agent
    stored under the ``'spot'`` key.
    """

    agent_map: dict

    def __init__(self, keys, account=None):
        super().__init__(keys, account)
        self._symbol_added = set()
        self._create_agent()

    def add_symbol(self, symbol):
        """Register ``symbol`` with its agent, then remember it locally."""
        agent = self._get_agent(symbol)
        agent.add_symbol(symbol)
        self._symbol_added.add(symbol)

    def join(self):
        """Call ``join()`` on every agent."""
        for agent in self.agent_map.values():
            agent.join()

    def query_balance(self, symbol=None):
        """Forward the balance query to the ``'spot'`` agent."""
        spot_agent = self.agent_map['spot']
        spot_agent.query_balance(symbol)

    def query_all_balance(self):
        """Forward the all-balance query to the ``'spot'`` agent."""
        spot_agent = self.agent_map['spot']
        spot_agent.query_all_balance()

    def query_margin(self, symbol=None):
        """Query margin for ``symbol``, or for every registered symbol."""
        for target in self._iter_symbols(symbol):
            self._get_agent(target).query_margin(target)

    def query_all_margin(self):
        """Call ``query_all_margin()`` on every agent."""
        for agent in self.agent_map.values():
            agent.query_all_margin()

    def query_position(self, symbol=None):
        """Query position for ``symbol``, or for every registered symbol."""
        for target in self._iter_symbols(symbol):
            self._get_agent(target).query_position(target)

    def query_all_position(self):
        """Call ``query_all_position()`` on every agent."""
        for agent in self.agent_map.values():
            agent.query_all_position()

    def query_open_orders(self, symbol=None):
        """Query open orders for ``symbol``, or for every registered symbol."""
        for target in self._iter_symbols(symbol):
            self._get_agent(target).query_open_orders(target)

    def query_all_open_orders(self):
        """Call ``query_all_open_orders()`` on every agent."""
        for agent in self.agent_map.values():
            agent.query_all_open_orders()

    def place_order(self, order):
        """Place ``order`` through its symbol's agent and return the result."""
        return self._get_agent(order.symbol).place_order(order)

    def cancel_order(self, order):
        """Cancel ``order`` through its symbol's agent and return the result."""
        return self._get_agent(order.symbol).cancel_order(order)

    def modify_order(self, old, new, **options):
        """Modify via the agent chosen by ``old.symbol`` and return the result."""
        return self._get_agent(old.symbol).modify_order(old, new, **options)

    def cancel_all(self):
        """Call ``cancel_all()`` on every agent."""
        for agent in self.agent_map.values():
            agent.cancel_all()

    def ignore_offset(self, ignore=False):
        """Pass the ``ignore`` flag to every agent."""
        for agent in self.agent_map.values():
            agent.ignore_offset(ignore)

    def set_ws_mode(self, ws=True):
        """Pass the ``ws`` flag to every agent."""
        for agent in self.agent_map.values():
            agent.set_ws_mode(ws)

    def _iter_symbols(self, symbol):
        """Return ``[symbol]``, or the registered-symbol set when ``None``."""
        return self._symbol_added if symbol is None else [symbol]

    @abstract_method
    def _create_agent(self):
        """Subclass hook invoked at the end of ``__init__``."""

    @abstract_method
    def _get_agent(self, symbol):
        """Subclass hook; its return value is used as the agent for ``symbol``."""


class SpiRouter(Spi):
    """Spi facade that forwards symbol registration to a per-symbol agent."""

    agent_map: dict

    def __init__(self, keys, account=None):
        super().__init__(keys, account)
        self._create_agent()

    def add_symbol(self, symbol):
        """Register ``symbol`` with the agent returned by ``_get_agent``."""
        self._get_agent(symbol).add_symbol(symbol)

    @abstract_method
    def _create_agent(self):
        """Subclass hook invoked at the end of ``__init__``."""

    @abstract_method
    def _get_agent(self, symbol):
        """Subclass hook; its return value is used as the agent for ``symbol``."""


class FunctionsRouter(Functions):
    """Functions facade that merges or forwards calls across agents.

    ``get_all_*`` calls collect one mapping per agent and merge them, in
    ``agent_map`` iteration order, so later agents overwrite duplicate keys.
    Per-symbol calls go to the agent returned by ``_get_agent``.
    """

    agent_map: dict

    def __init__(self):
        super().__init__()
        self._create_agent()

    @staticmethod
    def _merged(mappings):
        """Fold an iterable of mappings into one new dict, in order."""
        combined = {}
        for mapping in mappings:
            combined.update(mapping)
        return combined

    def get_all_precision(self):
        """Return the merged ``get_all_precision()`` results of every agent."""
        agents = self.agent_map.values()
        return self._merged(agent.get_all_precision() for agent in agents)

    def get_all_ticker(self):
        """Return the merged ``get_all_ticker()`` results of every agent."""
        agents = self.agent_map.values()
        return self._merged(agent.get_all_ticker() for agent in agents)

    def get_value_factor(self, symbol):
        """Return the value factor reported by the symbol's agent."""
        return self._get_agent(symbol).get_value_factor(symbol)

    def get_all_contract_size(self):
        """Return the merged ``get_all_contract_size()`` results of every agent."""
        agents = self.agent_map.values()
        return self._merged(agent.get_all_contract_size() for agent in agents)

    def get_recent_trade(self, symbol):
        """Return the recent trade reported by the symbol's agent."""
        return self._get_agent(symbol).get_recent_trade(symbol)

    def simulate_deal(self, order, deal_amount, mid_price, account):
        """Delegate the simulation to the agent chosen by ``order.symbol``."""
        handler = self._get_agent(order.symbol)
        return handler.simulate_deal(order, deal_amount, mid_price, account)

    @abstract_method
    def _create_agent(self):
        """Subclass hook invoked at the end of ``__init__``."""

    @abstract_method
    def _get_agent(self, symbol):
        """Subclass hook; its return value is used as the agent for ``symbol``."""


# def create_msg_callback(channel):
#     feed_raw_fn = channel.markets.feed_raw
#     event = channel.event
#     exchange = channel.exchange
#     symbol = channel.symbol
#     frequency = channel.frequency
#
#     def callback(ws, msg):
#         feed_raw_fn(event, exchange, symbol, frequency, time.time(), msg)
#
#     return callback


if __name__ == '__main__':
    # Interactive debugging scratch area; everything here blocks on input().

    def try_socket():
        """Manually toggle a ``Socket`` connection (defined but not called below)."""
        _endpoint = 'wss://stream.binance.com:9443/ws/btcusdt@depth5@1000ms'
        logging.set_level(3)

        def _on_msg(ws, msg):
            print('msg:', msg)

        def _on_open(ws):
            print(ws)
            print('open')

        def _on_err(ws, err):
            # Not passed to Socket below; kept for manual experiments.
            print('err:', err)

        def _on_close(ws):
            print('close')

        def _before_connect(ws):
            # The socket is created with an empty URL; set it just in time.
            ws.url = _endpoint

        sock = Socket('', _on_open, _on_msg, _on_close, _before_connect)
        print(sock._socket)
        sock.connect()

        # Each Enter press alternates between disconnecting and connecting.
        while True:
            input(':')
            sock.disconnect()
            input(':')
            sock.connect()

        # while True:
        #     input(':')
        #     print(1)
        #     a.reconnect()
        #     print(2)
        #     a.reconnect()
        #     print(3)
        #     a.reconnect()
        #     print(4)

    api = ApiBasic({'api_key': 1, 'secret_key': 2}, None)

    # NOTE(review): ApiBasic in this file defines `_create_cid`; no public
    # `create_cid` is visible here -- confirm the base `Api` provides it.
    while True:
        print(api.create_cid())
        input(':')

















